""" Pytorch Inception-Resnet-V2 implementation
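Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License), which is
based upon Google's TensorFlow implementation and pretrained weights (Apache 2.0 License).

This file holds the "Ghost" variant: the residual blocks defined here multiply their skip
connection by a random per-channel factor on every forward pass.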
"""

from functools import partial
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor

from torchvision.transforms._presets import ImageClassification
from torchvision.utils import _log_api_usage_once
from torchvision.models._api import  Weights, WeightsEnum
from timm.models.registry import register_model
from torchvision.models._meta import _IMAGENET_CATEGORIES
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models.helpers import build_model_with_cfg
from timm.models.layers import create_classifier
from timm.models.registry import register_model

# Names exported by a star-import of this module.
__all__ = ['GhostInceptionResnetV2']

# Pretrained-weight configuration, keyed by variant name. Both checkpoints share
# every setting except the download URL, so the entries are generated from one
# template instead of being spelled out twice.
default_cfgs = {
    variant: {
        'url': weights_url,
        'num_classes': 1000,
        'input_size': (3, 299, 299),
        'pool_size': (8, 8),
        'crop_pct': 0.8975,
        'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN,
        'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'conv2d_1a.conv',
        'classifier': 'classif',
        'label_offset': 1,  # 1001 classes in pretrained weights
    }
    for variant, weights_url in (
        # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
        (
            'inception_resnet_v2',
            'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
        ),
        # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
        (
            'ens_adv_inception_resnet_v2',
            'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
        ),
    )
}


def uniform_random(shape, minval, maxval, device):
    """Sample a tensor whose entries are uniformly distributed between two bounds.

    Args:
        shape: Shape of the tensor to create.
        minval: Lower bound of the sampling range.
        maxval: Upper bound of the sampling range.
        device: Device on which the tensor is allocated.

    Returns:
        A tensor of the requested shape that does not require gradients,
        obtained by stretching ``torch.rand`` samples to span
        ``maxval - minval`` and shifting them by ``minval``.
    """
    span = maxval - minval
    unit_samples = torch.rand(shape, requires_grad=False, device=device)
    return span * unit_samples + minval


class BasicConv2d(nn.Module):
    """Bias-free 2D convolution followed by batch normalisation and ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Convolution kernel size (int or pair).
        stride: Convolution stride.
        padding: Zero padding applied by the convolution (default 0).
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-3)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.relu(self.bn(self.conv(x)))


class Mixed_5b(nn.Module):
    """Four-branch inception block over a 192-channel input.

    The branches emit 96, 64, 96 and 64 channels respectively and are
    concatenated along the channel axis (320 channels in total).
    """

    def __init__(self):
        super().__init__()

        # 1x1 projection.
        self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)

        # 1x1 reduction followed by a 5x5 convolution.
        self.branch1 = nn.Sequential(
            BasicConv2d(192, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2),
        )

        # 1x1 reduction followed by two stacked 3x3 convolutions.
        self.branch2 = nn.Sequential(
            BasicConv2d(192, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
        )

        # Average pooling (padding excluded from the mean) then a 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(192, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the outputs on dim 1."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)


class GhostBlock35(nn.Module):
    """Residual inception block on 320 channels with a randomly scaled skip path.

    Three branches (32 + 32 + 64 channels) are concatenated, projected back to
    320 channels by a 1x1 convolution, scaled by ``scale`` and added to the
    input. Before the addition the input is multiplied by a random factor per
    channel, drawn from ``[1 - ghost_random_range, 1 + ghost_random_range]``.
    The factor is redrawn on every forward call, regardless of train/eval mode.

    Args:
        ghost_random_range: Half-width of the interval around 1 from which the
            skip-connection factors are sampled.
        scale: Multiplier applied to the residual branch output.
    """

    def __init__(self, ghost_random_range, scale=1.0):
        super().__init__()

        self.scale = scale
        self.ghost_random_range = ghost_random_range

        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        self.branch1 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
        )

        self.branch2 = nn.Sequential(
            BasicConv2d(320, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1),
        )

        # 1x1 projection of the 128 concatenated channels back to the input width.
        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(residual * scale + x * random_channel_factor)``."""
        branch_outputs = (self.branch0(x), self.branch1(x), self.branch2(x))
        residual = self.conv2d(torch.cat(branch_outputs, 1))

        # One factor per channel, broadcast over the batch and spatial dims.
        channels = x.shape[1]
        skip_factor = uniform_random(
            shape=(1, channels, 1, 1),
            minval=1 - self.ghost_random_range,
            maxval=1 + self.ghost_random_range,
            device=x.device,
        )
        return self.relu(residual * self.scale + x * skip_factor)


class Mixed_6a(nn.Module):
    """Stride-2 reduction block over a 320-channel input.

    Two convolutional branches (384 channels each) and a max-pooling branch
    (which keeps the 320 input channels) are concatenated along the channel
    axis, giving 1088 channels.
    """

    def __init__(self):
        super().__init__()

        # Single strided 3x3 convolution.
        self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)

        # 1x1 -> 3x3 -> strided 3x3.
        self.branch1 = nn.Sequential(
            BasicConv2d(320, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        )

        # Strided max pooling.
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the outputs on dim 1."""
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat([branch(x) for branch in branches], 1)


class GhostBlock17(nn.Module):
    """Residual inception block on 1088 channels with a randomly scaled skip path.

    Two branches (192 + 192 channels) are concatenated, projected back to 1088
    channels by a 1x1 convolution, scaled by ``scale`` and added to the input.
    Before the addition the input is multiplied by a random factor per
    channel, drawn from ``[1 - ghost_random_range, 1 + ghost_random_range]``.
    The factor is redrawn on every forward call, regardless of train/eval mode.

    Args:
        ghost_random_range: Half-width of the interval around 1 from which the
            skip-connection factors are sampled.
        scale: Multiplier applied to the residual branch output.
    """

    def __init__(self, ghost_random_range, scale=1.0):
        super().__init__()

        self.scale = scale
        self.ghost_random_range = ghost_random_range

        self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)

        # 1x1 reduction, then a 7x7 receptive field factorised as 1x7 and 7x1.
        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        )

        # 1x1 projection of the 384 concatenated channels back to the input width.
        self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(residual * scale + x * random_channel_factor)``."""
        branch_outputs = (self.branch0(x), self.branch1(x))
        residual = self.conv2d(torch.cat(branch_outputs, 1))

        # One factor per channel, broadcast over the batch and spatial dims.
        channels = x.shape[1]
        skip_factor = uniform_random(
            shape=(1, channels, 1, 1),
            minval=1 - self.ghost_random_range,
            maxval=1 + self.ghost_random_range,
            device=x.device,
        )
        return self.relu(residual * self.scale + x * skip_factor)


class Mixed_7a(nn.Module):
    """Stride-2 reduction block over a 1088-channel input.

    Three convolutional branches (384, 288 and 320 channels) and a max-pooling
    branch (which keeps the 1088 input channels) are concatenated along the
    channel axis, giving 2080 channels.
    """

    def __init__(self):
        super().__init__()

        # 1x1 reduction -> strided 3x3.
        self.branch0 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        )

        # 1x1 reduction -> strided 3x3 (narrower output).
        self.branch1 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2),
        )

        # 1x1 reduction -> 3x3 -> strided 3x3.
        self.branch2 = nn.Sequential(
            BasicConv2d(1088, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2),
        )

        # Strided max pooling.
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the outputs on dim 1."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)


class GhostBlock8(nn.Module):
    """Residual inception block on 2080 channels with a randomly scaled skip path.

    Two branches (192 + 256 channels) are concatenated, projected back to 2080
    channels by a 1x1 convolution, scaled by ``scale`` and added to the input.
    Before the addition the input is multiplied by a random factor per
    channel, drawn from ``[1 - ghost_random_range, 1 + ghost_random_range]``.
    The factor is redrawn on every forward call, regardless of train/eval mode.

    Args:
        scale: Multiplier applied to the residual branch output.
        no_relu: When true, the final ReLU is skipped and the raw sum is returned.
        ghost_random_range: Half-width of the interval around 1 from which the
            skip-connection factors are sampled. The default of 0.0 yields a
            factor of exactly 1, i.e. an unperturbed skip connection. It is
            placed last (unlike in the sibling Ghost blocks) so the existing
            positional order ``(scale, no_relu)`` keeps working.
    """

    def __init__(self, scale=1.0, no_relu=False, ghost_random_range=0.0):
        super(GhostBlock8, self).__init__()

        self.scale = scale
        # forward() reads this attribute, so it must be stored here.
        self.ghost_random_range = ghost_random_range

        self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)

        # 1x1 reduction, then a 3x3 receptive field factorised as 1x3 and 3x1.
        self.branch1 = nn.Sequential(
            BasicConv2d(2080, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
        )

        # 1x1 projection of the 448 concatenated channels back to the input width.
        self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
        self.relu = None if no_relu else nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``residual * scale + x * random_channel_factor``, with ReLU unless disabled."""
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        out = torch.cat((x0, x1), 1)
        out = self.conv2d(out)
        # One factor per channel, broadcast over the batch and spatial dims.
        c = x.shape[1]
        ghost_factor = uniform_random(
            shape=(1, c, 1, 1),
            minval=1 - self.ghost_random_range,
            maxval=1 + self.ghost_random_range,
            device=x.device,
        )
        out = out * self.scale + x * ghost_factor
        if self.relu is not None:
            out = self.relu(out)
        return out


class GhostInceptionResnetV2(nn.Module):
    """Inception-ResNet-V2 backbone assembled from Ghost residual blocks.

    Args:
        ghost_random_range: Passed to every Ghost block; half-width of the
            interval around 1 used to randomly scale the skip connections.
        num_classes: Size of the classifier output.
        in_chans: Number of channels of the input image.
        drop_rate: Dropout probability applied to the pooled features before
            the classifier (no dropout when 0).
        output_stride: Must be 32; any other value fails the assertion.
        global_pool: Pooling type handed to ``create_classifier``.
    """

    def __init__(self, ghost_random_range=0.08, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
        super().__init__()
        self.drop_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = 1536
        assert output_stride == 32

        # Stem.
        self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # feature_info records channel count, reduction and module name for
        # selected stages (presumably read by timm's feature helpers — TODO confirm).
        self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]

        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.feature_info.append(dict(num_chs=192, reduction=4, module='conv2d_4a'))

        # 320-channel stage: ten Ghost blocks.
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(*[
            GhostBlock35(ghost_random_range=ghost_random_range, scale=0.17)
            for _ in range(10)
        ])
        self.feature_info.append(dict(num_chs=320, reduction=8, module='repeat'))

        # 1088-channel stage: twenty Ghost blocks.
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*[
            GhostBlock17(ghost_random_range=ghost_random_range, scale=0.10)
            for _ in range(20)
        ])
        self.feature_info.append(dict(num_chs=1088, reduction=16, module='repeat_1'))

        # 2080-channel stage: nine Ghost blocks plus one final block without ReLU.
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*[
            GhostBlock8(ghost_random_range=ghost_random_range, scale=0.20)
            for _ in range(9)
        ])
        self.block8 = GhostBlock8(ghost_random_range=ghost_random_range, no_relu=True)
        self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
        self.feature_info.append(dict(num_chs=self.num_features, reduction=32, module='conv2d_7b'))

        pooling, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
        self.global_pool = pooling
        self.classif = head

    def get_classifier(self):
        """Return the classification head module."""
        return self.classif

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the pooling layer and classifier for a new number of classes."""
        self.num_classes = num_classes
        pooling, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
        self.global_pool = pooling
        self.classif = head

    def forward_features(self, x):
        """Run ``x`` through the backbone and return the output of ``conv2d_7b``."""
        pipeline = (
            self.conv2d_1a, self.conv2d_2a, self.conv2d_2b,
            self.maxpool_3a, self.conv2d_3b, self.conv2d_4a,
            self.maxpool_5a, self.mixed_5b, self.repeat,
            self.mixed_6a, self.repeat_1,
            self.mixed_7a, self.repeat_2, self.block8, self.conv2d_7b,
        )
        for stage in pipeline:
            x = stage(x)
        return x

    def forward(self, x):
        """Backbone -> global pooling -> optional dropout -> classifier."""
        pooled = self.global_pool(self.forward_features(x))
        if self.drop_rate > 0:
            pooled = F.dropout(pooled, p=self.drop_rate, training=self.training)
        return self.classif(pooled)


def _create_inception_resnet_v2(variant,  pretrained=True, **kwargs):
    """Build a ``GhostInceptionResnetV2`` for ``variant`` via ``build_model_with_cfg``.

    Args:
        variant: Key into ``default_cfgs``; an unknown key raises ``KeyError``.
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Forwarded unchanged to ``build_model_with_cfg``.
    """
    # NOTE(review): the config is handed over under the `model_cfg` keyword;
    # confirm against the installed timm release that this is the keyword
    # meant for the pretrained/default config — TODO verify.
    variant_cfg = default_cfgs[variant]
    return build_model_with_cfg(
        GhostInceptionResnetV2, variant, pretrained, model_cfg=variant_cfg, **kwargs)


def inception_resnet_v2(**kwargs):
    """Build the ``'inception_resnet_v2'`` variant.

    Args:
        **kwargs: Forwarded to ``_create_inception_resnet_v2``. An optional
            ``pretrained`` entry overrides the default of ``True``.

    Returns:
        Whatever ``_create_inception_resnet_v2`` returns for this variant.
    """
    # Pull `pretrained` out first so it is not passed twice; previously all
    # keyword arguments were accepted but silently discarded.
    pretrained = kwargs.pop('pretrained', True)
    return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)


@register_model
def ghost_inception_resnet_v2(pretrained=True, **kwargs):
    r"""Build the Ghost Inception-ResNet-V2 model using the ``'inception_resnet_v2'`` config.

    The base architecture is described in the "InceptionV4, Inception-ResNet..."
    paper, https://arxiv.org/abs/1602.07261.

    Args:
        pretrained: Forwarded to ``_create_inception_resnet_v2``.
        **kwargs: Forwarded to ``_create_inception_resnet_v2``.
    """
    variant = 'inception_resnet_v2'
    return _create_inception_resnet_v2(variant, pretrained=pretrained, **kwargs)


@register_model
def ens_adv_inception_resnet_v2(pretrained=True, **kwargs):
    r"""Build the model using the ``'ens_adv_inception_resnet_v2'`` config.

    That config refers to the ensemble adversarially trained Inception-ResNet-V2
    weights, as per https://arxiv.org/abs/1705.07204 and
    https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.

    Args:
        pretrained: Forwarded to ``_create_inception_resnet_v2``.
        **kwargs: Forwarded to ``_create_inception_resnet_v2``.
    """
    variant = 'ens_adv_inception_resnet_v2'
    return _create_inception_resnet_v2(variant, pretrained=pretrained, **kwargs)